{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# environment setting\n",
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# this should be learning_models\n",
    "curr_dir = os.getcwd()\n",
    "\n",
    "# this should be pyhealth, which is two level up from learning_models library\n",
    "root_dir = Path(curr_dir).parents[1]\n",
    "os.chdir(root_dir)\n",
    "\n",
    "sys.path.append(root_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:01<00:00, 93.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current ExpData_ID: 2020.0810.image --- Target for CMS\n",
      "finished X generate\n",
      "finished Y generate\n",
      "generate finished\n",
      "target Task: diagnose\n",
      "N of labels: 3\n",
      "N of TrainData: 64\n",
      "N of ValidData: 16\n",
      "N of TestData: 20\n",
      "load finished\n",
      "target Task: diagnose\n",
      "N of labels: 3\n",
      "N of TrainData: 64\n",
      "N of ValidData: 16\n",
      "N of TestData: 20\n"
     ]
    }
   ],
   "source": [
    "from pyhealth.data.expdata_generator import imagedata as expdata_generator\n",
    "\n",
    "# override here to specify where the data locates\n",
    "# root_dir = ''\n",
    "data_dir = os.path.join(root_dir, 'datasets', 'image')\n",
    "\n",
    "expdata_id = '2020.0810.image'\n",
    "\n",
    "# set up the datasets\n",
    "cur_dataset = expdata_generator(expdata_id, root_dir=root_dir)\n",
    "cur_dataset.get_exp_data(sel_task='diagnose', data_root=data_dir)\n",
    "cur_dataset.load_exp_data()\n",
    "# cur_dataset.show_data()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "use 2 GPUs recource\n",
      "current task can beed seen as multiclass\n",
      "=> creating model 'resnet18'\n",
      "    Total params: 11.18M\n",
      "=> creating model 'resnet18'\n",
      "    Total params: 11.18M\n",
      "load best-th epoch model\n",
      "current task can beed seen as multiclass\n",
      "{'hat_y': array([[0.34585968, 0.34950513, 0.30463523],\n",
      "       [0.3371851 , 0.339385  , 0.32342988],\n",
      "       [0.33524728, 0.3321462 , 0.3326065 ],\n",
      "       [0.34149212, 0.32257894, 0.3359289 ],\n",
      "       [0.35077342, 0.3396097 , 0.30961686],\n",
      "       [0.35468316, 0.32482758, 0.32048926],\n",
      "       [0.35699597, 0.33426023, 0.3087438 ],\n",
      "       [0.35296798, 0.33218038, 0.31485164],\n",
      "       [0.33812657, 0.34893402, 0.3129394 ],\n",
      "       [0.33851597, 0.341828  , 0.31965604],\n",
      "       [0.35516188, 0.3376512 , 0.307187  ],\n",
      "       [0.34562922, 0.34481993, 0.30955082],\n",
      "       [0.34262735, 0.34716234, 0.3102103 ],\n",
      "       [0.34689385, 0.3238576 , 0.32924852],\n",
      "       [0.3396384 , 0.33442318, 0.3259385 ],\n",
      "       [0.35435236, 0.32238114, 0.32326648],\n",
      "       [0.32515126, 0.35782436, 0.31702438],\n",
      "       [0.34224582, 0.34261426, 0.31513995],\n",
      "       [0.3434586 , 0.33913156, 0.3174099 ],\n",
      "       [0.33704153, 0.32978615, 0.33317235]], dtype=float32), 'y': array([[0., 1., 0.],\n",
      "       [0., 1., 0.],\n",
      "       [0., 0., 1.],\n",
      "       [0., 1., 0.],\n",
      "       [0., 1., 0.],\n",
      "       [1., 0., 0.],\n",
      "       [1., 0., 0.],\n",
      "       [1., 0., 0.],\n",
      "       [0., 1., 0.],\n",
      "       [0., 0., 1.],\n",
      "       [0., 1., 0.],\n",
      "       [0., 0., 1.],\n",
      "       [0., 1., 0.],\n",
      "       [0., 1., 0.],\n",
      "       [0., 0., 1.],\n",
      "       [0., 0., 1.],\n",
      "       [1., 0., 0.],\n",
      "       [0., 1., 0.],\n",
      "       [0., 0., 1.],\n",
      "       [0., 1., 0.]], dtype=float32)}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "tr=>epoch=99 Valid Loss: 21.487, Train Loss: 21.485: 100%|██████████| 100/100 [02:46<00:00,  1.66s/it]\n"
     ]
    }
   ],
   "source": [
    "from pyhealth.models.image.typicalcnn import TypicalCNN as model\n",
    "# initialize the model for training\n",
    "expmodel_id = '2020.0810.cnn.image.diagnose.'\n",
    "clf = model(expmodel_id=expmodel_id, n_epoch=100, use_gpu=True,\n",
    "            gpu_ids='0,1')\n",
    "clf.fit(cur_dataset.train, cur_dataset.valid)\n",
    "\n",
    "# load the best model for inference\n",
    "clf.load_model()\n",
    "clf.inference(cur_dataset.test)\n",
    "results = clf.get_results()\n",
    "print(results)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "current data evaluate using multiclass evaluation-type\n",
      "{'avg_precision_micro': 0.4080676029940976, 'roc_auc_score_micro': 0.5475, 'coverage_error': 1.95, 'label_ranking_average_precision_score': 0.6416666666666668, 'label_ranking_loss': 0.475, 'hamming_loss@1': 0.4, 'recall@1': 0.4, 'precision@1': 0.4, 'hamming_loss@3': 0.6666666666666666, 'recall@3': 1.0, 'precision@3': 0.33333333333333326}\n"
     ]
    }
   ],
   "source": [
    "from pyhealth.evaluation.evaluator import func\n",
    "# evaluate the model\n",
    "r = func(results['hat_y'], results['y'])\n",
    "print(r)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}